

# ################################################
# Unified: reparam + EL
# ################################################

optimize_unified_reparam <- function(dat, idx_test, fmla_f, fmla_m, px, beta_start) {
  # Jointly fit the mediator (M, logistic) and outcome (Y, squared-error)
  # models by minimising the mean negative log-likelihood with nloptr/COBYLA.
  # The outcome coefficient `wa` is held fixed at 0 throughout.
  #
  # Args:
  #   dat:        data frame with at least columns `M` and `Y`.
  #   idx_test:   row indices whose observed Y is replaced by its fitted value,
  #               so those rows contribute no outcome residual.
  #   fmla_f, fmla_m: formulas for the outcome and mediator design matrices.
  #   px:         weights forwarded unchanged to estimate_Y().
  #   beta_start: flat numeric start vector ordered (mediator coefs, outcome
  #               coefs, w0); when empty every entry starts at 0.1.
  #
  # Returns:
  #   list(beta_m = named mediator coefficients,
  #        beta_y = named outcome coefficients followed by w0 and wa = 0).

  # Design matrices (model frame built once and shared)
  mf <- model.frame(dat, na.action = NULL)
  Xm <- as.matrix(model.matrix(fmla_m, data = mf))
  Xf <- as.matrix(model.matrix(fmla_f, data = mf))
  n_m <- ncol(Xm)

  # Initial values: mediator coefs, outcome coefs, plus one slot for w0
  if (length(beta_start) == 0) {
    beta_start <- rep(0.1, n_m + ncol(Xf) + 1)
    names(beta_start) <- c(colnames(Xm), colnames(Xf), "w0")
  }

  # Mean negative log-likelihood of the stacked parameter vector
  eval_f <- function(beta, dat, Xm, px) {
    n <- nrow(dat)
    p <- length(beta)
    beta_m <- beta[1:ncol(Xm)]
    beta_f <- beta[(ncol(Xm) + 1):(p - 1)]
    names(beta_m) <- colnames(Xm)
    names(beta_f) <- colnames(Xf)
    # wa is fixed at 0 in this reparameterisation
    beta_y <- c(beta_f, w0 = beta[[p]], wa = 0)

    M <- dat$M
    Y <- dat$Y
    Y_hat <- estimate_Y(dat, beta_y, beta_m, px)
    Y[idx_test] <- Y_hat[idx_test]

    # Log-scale logistic probabilities: log(1 + exp(.)) overflows for large
    # |eta| and 0 * Inf would turn the whole objective into NaN.
    eta <- drop(Xm %*% beta_m)
    loglik_m <- sum(
      M * plogis(eta, log.p = TRUE) +
        (1 - M) * plogis(eta, lower.tail = FALSE, log.p = TRUE)
    )
    loglik_y <- sum(-(Y - Y_hat)^2 / 2)
    -(loglik_m + loglik_y) / n
  }

  # Solve the optimization problem
  mle_res <- nloptr(
    x0 = beta_start,
    eval_f = eval_f,
    opts = list("algorithm" = "NLOPT_LN_COBYLA", "xtol_rel" = 1.0e-3,
                "maxeval" = 5000),
    dat = dat, Xm = Xm, px = px
  )

  # Return the parameters, re-attaching names (nloptr's solution may be unnamed)
  beta <- mle_res$solution
  p <- length(beta)
  beta_m <- beta[1:n_m]
  beta_y <- c(beta[(n_m + 1):p], 0)
  names(beta_m) <- colnames(Xm)
  names(beta_y) <- c(colnames(Xf), "w0", "wa")

  list(beta_m = beta_m, beta_y = beta_y)
}

# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

get_m_unified <- function(beta, dat, px) {
  # Per-row contrast between the A = 1 and A = 0 outcome predictions, weighted
  # by the fitted mediator probabilities evaluated at A = 0.
  #
  # Args:
  #   beta: list(beta_m = named mediator coefficients whose first entry pairs
  #         with column 1 of the processed data; beta_y = named outcome
  #         coefficients whose last two entries are w0 and wa).
  #   dat:  data whose column names are matched to the coefficient names; it is
  #         passed to process_data() to build the four (a, m) copies.
  #   px:   weights used when centring each arm's predictions.
  #
  # Returns:
  #   The contrast, one value per row (a column from `%*%` arithmetic).
  coef_m <- beta$beta_m
  coef_y <- beta$beta_y
  k <- length(coef_y)

  # Copies of the data with (A, M) set to each combination
  x_a0m0 <- process_data(dat, a = 0, m = 0)
  x_a0m1 <- process_data(dat, a = 0, m = 1)
  x_a1m0 <- process_data(dat, a = 1, m = 0)
  x_a1m1 <- process_data(dat, a = 1, m = 1)

  # Mediator model P(M = 1 | A = 0, C).
  # NOTE(review): column 1 is assumed to be an intercept column — confirm
  # against process_data().
  cols_m <- c(1, match(names(coef_m)[-1], colnames(dat)))
  pr_m1 <- 1 / (1 + exp(-x_a0m1[, cols_m] %*% coef_m))
  pr_m0 <- 1 - pr_m1

  # Outcome model: all but the trailing (w0, wa) pair are matched by name
  coef_f <- coef_y[1:(k - 2)]
  w0 <- coef_y[k - 1]
  wa <- coef_y[k]
  cols_f <- match(names(coef_f), colnames(dat))
  lin <- function(x) x[, cols_f] %*% coef_f
  f_a1m1 <- lin(x_a1m1)
  f_a1m0 <- lin(x_a1m0)
  f_a0m1 <- lin(x_a0m1)
  f_a0m0 <- lin(x_a0m0)

  # Each arm is centred by its px-weighted mediator-averaged prediction;
  # both mediator levels of an arm share the same shift.
  shift_a1 <- sum(px * (f_a1m1 * pr_m1 + f_a1m0 * pr_m0))
  shift_a0 <- sum(px * (f_a0m1 * pr_m1 + f_a0m0 * pr_m0))

  y_a1m1 <- f_a1m1 - shift_a1 + w0 + wa
  y_a1m0 <- f_a1m0 - shift_a1 + w0 + wa
  y_a0m1 <- f_a0m1 - shift_a0 + w0
  y_a0m0 <- f_a0m0 - shift_a0 + w0

  (y_a1m0 - y_a0m0) * pr_m0 + (y_a1m1 - y_a0m1) * pr_m1
}


# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++


optimize_unified <- function(dat, idx_test, fmla_f, fmla_m, px, beta_start, opt) {
  # Alternating fit. Each iteration:
  #   1. computes the contrast m with get_m_unified() at the current coefs,
  #   2. derives new weights via get_lambda() / get_pi(),
  #   3. refits the coefficients with optimize_unified_reparam() under them.
  # It stops when the weights and both coefficient vectors change (L2 norm) by
  # less than the tolerances, or after opt$max_iter iterations.
  #
  # Args:
  #   dat, idx_test, fmla_f, fmla_m: forwarded to optimize_unified_reparam().
  #   px:         initial weights (presumably one per row of dat — confirm).
  #   beta_start: list(beta_m, beta_y) with beta_y ending in (w0, wa); when
  #               empty every coefficient starts at 0.1.
  #   opt:        list with max_iter (>= 1). Optional tol_pi (default 1e-3)
  #               and tol_beta (default 5e-3) set the convergence tolerances.
  #               opt$threshold is not used.
  #
  # Returns:
  #   list(beta_m, beta_y, px = final weights,
  #        mle = log-likelihood of the M (logistic) and Y (unit-variance
  #              Gaussian) models, with idx_test outcomes set to their fit,
  #        Yhat, nde = sum(px * m)).
  max_iter <- opt$max_iter
  stopifnot(length(max_iter) == 1, max_iter >= 1)
  tol_pi <- if (is.null(opt[["tol_pi"]])) 1e-3 else opt[["tol_pi"]]
  tol_beta <- if (is.null(opt[["tol_beta"]])) 5e-3 else opt[["tol_beta"]]

  # Design matrices (model frame built once and shared)
  mf <- model.frame(dat, na.action = NULL)
  Xm <- as.matrix(model.matrix(fmla_m, data = mf))
  Xf <- as.matrix(model.matrix(fmla_f, data = mf))

  # Initial values: mediator coefs, outcome coefs, plus w0 and wa
  if (length(beta_start) == 0) {
    init <- rep(0.1, ncol(Xm) + ncol(Xf) + 2)
    names(init) <- c(colnames(Xm), colnames(Xf), "w0", "wa")
    beta_start <- list(beta_m = init[1:ncol(Xm)],
                       beta_y = init[(ncol(Xm) + 1):length(init)])
  }

  for (i in seq_len(max_iter)) {
    cat("iter: ", i, "\n")

    # Note: the default start carries wa = 0.1, but the refit below always
    # returns wa = 0, so later iterations evaluate m with wa = 0.
    m <- get_m_unified(beta_start, dat, px)
    lambda <- get_lambda(m, dat)
    pi_hat <- get_pi(m, lambda, dat)

    # The refit takes a flat vector without the trailing wa entry
    beta_m_st <- beta_start$beta_m
    beta_y_st <- beta_start$beta_y[-length(beta_start$beta_y)]

    beta_up <- optimize_unified_reparam(dat, idx_test, fmla_f, fmla_m, pi_hat,
                                        c(beta_m_st, beta_y_st))
    beta_m_up <- beta_up$beta_m
    beta_y_up <- beta_up$beta_y[-length(beta_up$beta_y)]

    err_beta_m <- sqrt(sum((beta_m_st - beta_m_up)^2))
    err_beta_y <- sqrt(sum((beta_y_st - beta_y_up)^2))
    err_pi <- sqrt(sum((pi_hat - px)^2))

    cat("nde = ", sum(px * m), "\n")

    if (err_pi < tol_pi && err_beta_m < tol_beta && err_beta_y < tol_beta) {
      break
    }
    beta_start <- beta_up
    px <- pi_hat
  }

  beta_m <- beta_up$beta_m
  beta_y <- beta_up$beta_y

  px <- pi_hat
  m <- get_m_unified(beta_up, dat, px)
  nde <- sum(px * m)

  # Log-likelihood, computed on the log scale to avoid log(0) underflow
  Yhat <- estimate_Y(dat, beta_y, beta_m, px)
  Y <- dat$Y
  Y[idx_test] <- Yhat[idx_test]
  log_lik_Y <- sum(dnorm(Y, Yhat, 1, log = TRUE))

  M <- dat$M
  eta <- drop(Xm %*% beta_m)
  log_lik_M <- sum(
    M * plogis(eta, log.p = TRUE) +
      (1 - M) * plogis(eta, lower.tail = FALSE, log.p = TRUE)
  )

  list(beta_m = beta_m,
       beta_y = beta_y,
       px = px,
       mle = log_lik_M + log_lik_Y,
       Yhat = Yhat,
       nde = nde)
}





